1 The data

In smoking_status Unknown should be changed to NA.

Also, it can be ordered: never < formerly < smokes

ever_married can be recoded as 0/1 in accordance with heart_disease and hypertension

Other predictors seem to be OK

# Read the data. col_types is positional: "c" = character, "f" = factor,
# "d" = double. "Unknown" and "N/A" strings are read as NA.
df <- read_csv("data/healthcare-dataset-stroke-data.csv", col_types = "cfdfffffddcf", na = c("Unknown", "N/A"))

# if you set smoking_status to factor in col_types, na() won't work
# so it was read as character ("c") above and is converted to a factor here
df$smoking_status <- as_factor(df$smoking_status)
# impose the natural ordering: never < formerly < smokes
df$smoking_status <- fct_relevel(df$smoking_status, c("never smoked", "formerly smoked", "smokes"))

# married
# recode ever_married as a 0/1 factor, consistent with heart_disease/hypertension
df$ever_married <- factor(if_else(df$ever_married == "Yes", 1, 0))

# for models working properly
# caret needs syntactically valid class labels; "no" becomes the reference level
df$stroke <- factor(ifelse(df$stroke == 1, "yes", "no"), levels = c("no", "yes"))

df

1.1 Descriptive statistics

Skip id column

# Drop the identifier column — it carries no predictive information.
df <- df %>% select(-id)
skimr::skim(df)
Data summary
Name df
Number of rows 5110
Number of columns 11
_______________________
Column type frequency:
factor 8
numeric 3
________________________
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
gender 0 1.0 FALSE 3 Fem: 2994, Mal: 2115, Oth: 1
hypertension 0 1.0 FALSE 2 0: 4612, 1: 498
heart_disease 0 1.0 FALSE 2 0: 4834, 1: 276
ever_married 0 1.0 FALSE 2 1: 3353, 0: 1757
work_type 0 1.0 FALSE 5 Pri: 2925, Sel: 819, chi: 687, Gov: 657
Residence_type 0 1.0 FALSE 2 Urb: 2596, Rur: 2514
smoking_status 1544 0.7 FALSE 3 nev: 1892, for: 885, smo: 789
stroke 0 1.0 FALSE 2 no: 4861, yes: 249

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
age 0 1.00 43.23 22.61 0.08 25.00 45.00 61.00 82.00 ▅▆▇▇▆
avg_glucose_level 0 1.00 106.15 45.28 55.12 77.24 91.88 114.09 271.74 ▇▃▁▁▁
bmi 201 0.96 28.89 7.85 10.30 23.50 28.10 33.10 97.60 ▇▇▁▁▁

Target ‘stroke’ is imbalanced!

‘smoking_status’ completeness rate 0.7

1.2 How many smoking_status in each target class?

df %>% group_by(stroke, smoking_status) %>% summarise(N=n())

BMI’s complete rate 0.96

1.3 How many skipped BMI in each target class?

df %>% filter(is.na(bmi)) %>% group_by(stroke) %>% summarise(N=n())

One ‘Other’ gender to be removed

df <- df %>% filter(gender != "Other")

1.4 EDA

1.4.1 Overview: a pairs plot

# Pairs plot of all variables, colored by target class.
# NOTE(review): `dotsize` is not a standard ggplot2 aesthetic; passing it
# through aes() is likely ignored (or warned about) — confirm intent.
GGally::ggpairs(df, aes(color = stroke, alpha = 0.2, dotsize = 0.02), 
        upper = list(continuous = GGally::wrap("cor", size = 2.5)),
        diag = list(continuous = "barDiag")) +
  scale_color_brewer(palette = "Set1", direction = -1) +
  scale_fill_brewer(palette = "Set1", direction = -1)

1.4.2 In details

1.4.2.1 Stroke vs Age

# Stroke vs age: variable-width, notched boxplot overlaid with a violin.
# Use TRUE instead of the reassignable shorthand T.
ggplot(df, aes(stroke, age)) +
  geom_boxplot(aes(fill = stroke), alpha = 0.5, varwidth = TRUE, notch = TRUE) +
  geom_violin(aes(fill = stroke), alpha = 0.5) +
  scale_fill_brewer(palette = "Set1", direction = -1) +
  xlab("")

OBS! There are observations with age much below 20 y.o., even close to 0!

These are very young kids or babies - should we even include them in the analysis?

If so, it will be prediction for adults. Also, stroke in kids probably has very different causes compared to stroke in adults/old folk.

1.4.2.2 Stroke vs Age + Gender

# Stroke vs age with individual patients jittered and colored by gender.
ggplot(df, aes(stroke, age)) +
  geom_violin(alpha = 0.3) +
  geom_jitter(aes(color = gender), alpha = 0.2, size = 0.8, width = 0.15, height = 0.1) +
  geom_boxplot(alpha = 0.2) +
  scale_color_brewer(palette = "Set2", direction = -1)

1.4.2.3 Stroke vs Glucose

# Stroke vs average glucose level: notched boxplot overlaid with a violin.
# Use TRUE instead of the reassignable shorthand T.
ggplot(df, aes(stroke, avg_glucose_level)) +
  geom_boxplot(aes(fill = stroke), alpha = 0.5, varwidth = TRUE, notch = TRUE) +
  geom_violin(aes(fill = stroke), alpha = 0.5) +
  scale_fill_brewer(palette = "Set1", direction = -1) +
  xlab("")

This average glucose level is probably the results of fasting blood sugar test

If I correctly understand this CDC information on diabetes, values greater than 126 are evidence of diabetes. But 250? Is it realistic?

1.4.2.4 Stroke vs BMI

# Stroke vs BMI: notched boxplot overlaid with a violin.
# Use TRUE instead of the reassignable shorthand T.
ggplot(df, aes(stroke, bmi)) +
  geom_boxplot(aes(fill = stroke), alpha = 0.5, varwidth = TRUE, notch = TRUE) +
  geom_violin(aes(fill = stroke), alpha = 0.5) +
  scale_fill_brewer(palette = "Set1", direction = -1) +
  xlab("")

BMI over 40 is the 3rd class of obesity - BMI over 75 should not exist at all.

Let’s look at these weird points

1.4.2.5 Age vs BMI vs Glucose

# Age vs BMI scatter, faceted by target class.
# Fix: the points map `color`, not `fill`, so the previous
# scale_fill_brewer() call was a no-op — use the color scale, matching
# the Set1 reversed palette of the other plots.
ggplot(df, aes(age, bmi)) +
  geom_point(aes(color = stroke), alpha = 0.8, size = 0.5) +
  scale_color_brewer(palette = "Set1", direction = -1) +
  facet_grid(rows = vars(stroke)) +
  guides(color = "none")

Patients with BMI over 75 are also very young. Suspicious.

1.4.2.6 Glucose vs Age + smoking

# Glucose vs age colored by smoking status, faceted by target class.
# Fixes: the points map `color`, so scale_fill_brewer() was a no-op —
# use the color scale; the empty guides() call did nothing and is dropped.
ggplot(df, aes(age, avg_glucose_level)) +
  geom_point(aes(color = smoking_status), alpha = 0.6, size = 1) +
  scale_color_brewer(palette = "Set1", direction = -1) +
  facet_grid(rows = vars(stroke))

OBS! Kids are mainly ‘Unknown’ smoking status; both target groups are divided into two clusters – I’m curious why.

It is not gender, nor heart disease or any other factor we have in the data set.

1.4.2.7 Age vs Smoking

# Age distribution per smoking status, split by target class.
# Use TRUE instead of the reassignable shorthand T.
ggplot(df, aes(smoking_status, age)) +
  geom_boxplot(aes(fill = stroke), alpha = 0.5, varwidth = TRUE, notch = TRUE) +
  scale_fill_brewer(palette = "Set1", direction = -1) +
  xlab("")

Kids are mainly non-smokers of course. Note the same two outliers of age below 20 in stroke-yes

1.4.2.8 Glucose vs BMI

# Glucose vs BMI colored by age (continuous), faceted by target class.
# Fixes: scale_fill_brewer() was a no-op (the points map `color`), and
# brewer scales are discrete so they cannot apply to continuous `age` —
# keep the default continuous color scale. The empty guides() call is
# dropped.
ggplot(df, aes(avg_glucose_level, bmi)) +
  geom_point(aes(color = age), alpha = 0.6, size = 1) +
  facet_grid(rows = vars(stroke))

BMI outliers: super high BMI but super low glucose levels? How’s that possible?

Can it be a bias introduced by testing protocol misuse? Like instead of measuring glucose before eating, in some samples it was done after eating?

1.4.2.9 Stroke vs Gender

# Gender proportions within each target class.
gender <- df %>% group_by(stroke, gender) %>% summarize(N=n())

# geom_col() is the idiomatic form of geom_bar(stat = "identity").
ggplot(gender, aes(stroke, N)) +
  geom_col(aes(fill = gender), alpha = 0.8, position = "fill") +
  scale_fill_brewer(palette = "Set2", direction = -1) +
  ylab("proportion")

Proportions in both stroke groups are roughly the same

1.4.2.10 Stroke vs Hypertension

# Hypertension proportions within each target class.
hyptens <- df %>% group_by(stroke, hypertension) %>% summarize(N = n())

# geom_col() is the idiomatic form of geom_bar(stat = "identity").
ggplot(hyptens, aes(stroke, N)) +
  geom_col(aes(fill = hypertension), alpha = 0.8, position = "fill") +
  scale_fill_brewer(palette = "Set2", direction = -1) +
  ylab("proportion")

Hypertension occurred more often in stroke-yes

1.4.2.11 Stroke vs Heart Disease

# Heart-disease proportions within each target class.
heart <- df %>% group_by(stroke, heart_disease) %>% summarize(N=n())

# geom_col() is the idiomatic form of geom_bar(stat = "identity").
ggplot(heart, aes(stroke, N)) +
  geom_col(aes(fill = heart_disease), alpha = 0.8, position = "fill") +
  scale_fill_brewer(palette = "Set2", direction = 1) +
  ylab("proportion")

1.4.2.12 Stroke vs Ever Married

# Ever-married proportions within each target class.
married <- df %>% group_by(stroke, ever_married) %>% summarize(N=n())

# geom_col() is the idiomatic form of geom_bar(stat = "identity").
ggplot(married, aes(stroke, N)) +
  geom_col(aes(fill = ever_married), alpha = 0.8, position = "fill") +
  scale_fill_brewer(palette = "Set2", direction = -1) +
  ylab("proportion")

Marriage is bad :)

1.4.2.13 Stroke vs Work Type

# Work-type proportions within each target class.
work <- df %>% group_by(stroke, work_type) %>% summarize(N=n())

# geom_col() is the idiomatic form of geom_bar(stat = "identity").
ggplot(work, aes(stroke, N)) +
  geom_col(aes(fill = work_type), alpha = 0.8, position = "fill") +
  scale_fill_brewer(palette = "Set2", direction = 1) +
  ylab("proportion")

It’s good to be a child

1.4.2.14 Stroke vs Residence Type

# Residence-type proportions within each target class.
residence <- df %>% group_by(stroke, Residence_type) %>% summarize(N=n())

# geom_col() is the idiomatic form of geom_bar(stat = "identity").
ggplot(residence, aes(stroke, N)) +
  geom_col(aes(fill = Residence_type), alpha = 0.8, position = "fill") +
  scale_fill_brewer(palette = "Set2", direction = 1) +
  ylab("proportion")

1.4.2.15 Stroke vs Smoking

# Smoking-status proportions within each target class (NA is its own group).
smoking <- df %>% group_by(stroke, smoking_status) %>% summarize(N=n())

# geom_col() is the idiomatic form of geom_bar(stat = "identity").
ggplot(smoking, aes(stroke, N)) +
  geom_col(aes(fill = smoking_status), alpha = 0.8, position = "fill") +
  scale_fill_brewer(palette = "Set2", direction = 1) +
  ylab("proportion")

1.4.3 Conclusions

There are several suspicious outliers (like in BMI and glucose). I can safely remove BMI > 75, maybe even BMI > 60 (Remember that BMI > 40 is the highest class of obesity).

What is less safe - it’s removing non-adults (younger than 20 y.o.). On one hand they provide valid information (age is very important predictor), on the other hand their stroke cases are really sus and a lot of predictors do not have sense (or are obvious NAs) for non-adults (like smoking, marriage status, employment type, residence type etc.); model-based imputation of smoking_status in children doesn’t have sense as well, I should rather replace with “never smoked”.

Since modelling using all predictors and observations has given me very moderate results (TPR = 1 comes with a very high FPR and a very low probability cutoff close to 0), I will try various trimmings of the data.

1.5 Trimming

This is the version that keeps the kids — only extreme BMI values are trimmed here.

# Trim implausible BMI outliers (BMI > 60). Note that filter() also
# drops rows where the condition is NA, so the 201 missing-BMI rows are
# removed here as well (df_trim's bmi column is complete).
df_trim <- filter(df, bmi <= 60)

skimr::skim(df_trim)
Data summary
Name df_trim
Number of rows 4895
Number of columns 11
_______________________
Column type frequency:
factor 8
numeric 3
________________________
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
gender 0 1.0 FALSE 2 Fem: 2888, Mal: 2007, Oth: 0
hypertension 0 1.0 FALSE 2 0: 4449, 1: 446
heart_disease 0 1.0 FALSE 2 0: 4652, 1: 243
ever_married 0 1.0 FALSE 2 1: 3193, 0: 1702
work_type 0 1.0 FALSE 5 Pri: 2798, Sel: 774, chi: 671, Gov: 630
Residence_type 0 1.0 FALSE 2 Urb: 2485, Rur: 2410
smoking_status 1479 0.7 FALSE 3 nev: 1847, for: 836, smo: 733
stroke 0 1.0 FALSE 2 no: 4686, yes: 209

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
age 0 1 42.87 22.57 0.08 25.00 44.00 60.00 82.00 ▅▆▇▇▆
avg_glucose_level 0 1 105.31 44.42 55.12 77.08 91.68 113.46 271.74 ▇▃▁▁▁
bmi 0 1 28.79 7.56 10.30 23.50 28.00 33.00 59.70 ▂▇▅▁▁

BMI is complete; in total approx. 200 observations are gone (4895 of 5109 rows remain).

1.6 Imputation

Using package mice

It uses polr - proportional odds model - for smoking_status and pmm - predictive mean matching - for bmi

1.6.1 Run imputation

library(mice)

# Multiple imputation with mice defaults (5 imputations, 5 iterations —
# see the iter/imp log below). After trimming, only smoking_status still
# has NAs, so it is the only variable being imputed (via polr).
# NOTE(review): no seed is passed, so the imputation is not reproducible
# across runs — consider mice(df_trim, seed = ...).
imp_mice <- mice(df_trim)
## 
##  iter imp variable
##   1   1  smoking_status
##   1   2  smoking_status
##   1   3  smoking_status
##   1   4  smoking_status
##   1   5  smoking_status
##   2   1  smoking_status
##   2   2  smoking_status
##   2   3  smoking_status
##   2   4  smoking_status
##   2   5  smoking_status
##   3   1  smoking_status
##   3   2  smoking_status
##   3   3  smoking_status
##   3   4  smoking_status
##   3   5  smoking_status
##   4   1  smoking_status
##   4   2  smoking_status
##   4   3  smoking_status
##   4   4  smoking_status
##   4   5  smoking_status
##   5   1  smoking_status
##   5   2  smoking_status
##   5   3  smoking_status
##   5   4  smoking_status
##   5   5  smoking_status
df_imp <- complete(imp_mice)

Number of NAs in BMI: 0

Number of NAs in Smoking: 0

1.6.2 Check distributions

1.6.2.1 BMI

# Compare original vs imputed BMI distributions per target class.
# mutate() recycles length-1 values, so the rep(..., nrow(...)) calls
# were unnecessary.
bmi_imp_comp <- bind_rows(
  select(df_trim, bmi, stroke) %>% mutate(type = "original"),
  select(df_imp, bmi, stroke) %>% mutate(type = "imputed")
)

ggplot(bmi_imp_comp, aes(bmi)) +
  geom_histogram(aes(fill = type), alpha = 0.8) +
  facet_grid(cols = vars(stroke))

Means have not changed, which is good, I suppose.

1.6.2.2 Smoking

# Compare original vs imputed smoking_status counts per target class.
# mutate() recycles length-1 values, so the rep(..., nrow(...)) calls
# were unnecessary.
smoke_imp_comp <- bind_rows(
  select(df_trim, smoking_status, stroke) %>% mutate(type = "original"),
  select(df_imp, smoking_status, stroke) %>% mutate(type = "imputed")
)

ggplot(smoke_imp_comp, aes(smoking_status)) +
  geom_bar(aes(fill = type), alpha = 0.8, position = "dodge") +
  facet_grid(cols = vars(stroke)) +
  xlab("") +
  theme(axis.text.x = element_text(angle = 45, vjust = 0.5))

Counts increased proportionally in all Smoking groups

1.7 Scaling & Normalization

Scale numeric features (including imputed BMI)

# use caret::preProcess()
# preProcValues <- preProcess(training, method = c("center", "scale"))

# Center and scale (z-score) the numeric features, keeping only them.
df_scaled <- data.frame(scale(select(df_imp, avg_glucose_level, age, bmi)))

1.8 Make Dummies

I had considered omitting smoking_status completely, but note that it is in fact included in the dummy set below.

# select vars
# NOTE(review): the text above says smoking_status would be omitted,
# but it IS included in the dummy set here — confirm which is intended.
to_dum <- df_imp %>% select(gender, work_type, Residence_type, smoking_status)
# make an obj
# dummyVars() builds a one-hot encoding recipe from the formula
dummies <- dummyVars(~ ., data = to_dum)
# apply it
df_dummy <- data.frame(predict(dummies, newdata = to_dum))

head(df_dummy)

1.9 Join scaled and dummies and the rest

# Join scaled numerics, dummies, and the untouched binary predictors.
# Consistency fix: take the remaining columns from df_imp rather than
# df_trim — these columns contain no NAs, so the rows are identical
# either way, but now every piece comes from the same (imputed) data set.
df_proc <- bind_cols(df_scaled, df_dummy, select(df_imp, hypertension, heart_disease, ever_married, stroke))
head(df_proc)

2 Modelling

2.1 Basic parameters

ROC-optimization is better when data is imbalanced

Kappa-optimization is also good

# Resampling setups: 10-fold CV repeated 10 times (the old comments
# incorrectly said "5-fold"). Use TRUE rather than the shorthand T.
# for ROC: needs class probabilities and the twoClassSummary metrics
fit_ctrl_roc <- trainControl(method = "repeatedcv",
                             number = 10,
                             repeats = 10,
                             allowParallel = TRUE,
                             classProbs = TRUE,
                             summaryFunction = twoClassSummary)
# for kappa: default summary function reports Accuracy and Kappa
fit_ctrl_kp <- trainControl(method = "repeatedcv",
                            number = 10,
                            repeats = 10,
                            allowParallel = TRUE)

2.2 Split data

Imbalanced data - use SMOTE to create training data set, but not testing data set

set.seed(1234)
# Stratified 75/25 train/test split on the target.
sample_set <- createDataPartition(y = df_proc$stroke, p = .75, list = FALSE)
df_train <- df_proc[sample_set,]
df_test <- df_proc[-sample_set,]

# DMwR::SMOTE for imbalanced data: perc.over = 1725 and perc.under = 106
# give an approximately 1:1 class ratio (the old comment's 225/150 values
# were stale). SMOTE is applied to the training set only; the test set
# keeps the original class distribution.
df_train_smote <- SMOTE(stroke ~ ., data.frame(df_train), perc.over = 1725, perc.under = 106)

df_train_smote %>% group_by(stroke) %>% summarise(N=n())

3 Random Forest

3.1 Training and validation

# start a cluster
library(doParallel)

# NOTE(review): THREADS is not defined in this chunk — it must be set
# earlier in the document; it controls the number of parallel workers.
cl <- makePSOCKcluster(THREADS)
registerDoParallel(cl)

3.1.1 Kappa-optimized

For imbalanced classes

set.seed(123)

# Random forest tuned over mtry = 2..19, selecting the model by Kappa
# (more informative than Accuracy on imbalanced data).
# verbosity/verbose are forwarded to the underlying fitting code to
# silence its progress output.
fit_rf <- train(stroke ~ ., 
                 data = df_train_smote, 
                 metric = "Kappa", 
                 method = "rf", 
                 trControl = fit_ctrl_kp,
                 tuneGrid = expand.grid(.mtry = seq(2, 19, 1)),
                 verbosity = 0,
                 verbose = FALSE)

fit_rf
## Random Forest 
## 
## 5655 samples
##   19 predictor
##    2 classes: 'no', 'yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times) 
## Summary of sample sizes: 5089, 5089, 5090, 5090, 5090, 5089, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    2    0.9552956  0.9105899
##    3    0.9623692  0.9247371
##    4    0.9641018  0.9282022
##    5    0.9653223  0.9306432
##    6    0.9661004  0.9321995
##    7    0.9659764  0.9319515
##    8    0.9660295  0.9320578
##    9    0.9659237  0.9318462
##   10    0.9657999  0.9315987
##   11    0.9654990  0.9309968
##   12    0.9648801  0.9297591
##   13    0.9641372  0.9282732
##   14    0.9641193  0.9282374
##   15    0.9633941  0.9267869
##   16    0.9631645  0.9263277
##   17    0.9626339  0.9252666
##   18    0.9622270  0.9244526
##   19    0.9617851  0.9235688
## 
## Kappa was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 6.

3.1.2 ROC-optimized

# cluster already started in the Random Forest section
#cl <- makePSOCKcluster(THREADS)
#registerDoParallel(cl)

set.seed(120)

# Same random-forest grid search, but selecting by ROC AUC — this
# requires fit_ctrl_roc (classProbs + twoClassSummary).
fit_rf_roc <- train(stroke ~ ., 
                 data = df_train_smote, 
                 metric = "ROC", 
                 method = "rf", 
                 trControl = fit_ctrl_roc,
                 tuneGrid = expand.grid(.mtry = seq(2, 19, 1)),
                 verbosity = 0,
                 verbose = FALSE)
#stopCluster(cl)

fit_rf_roc
## Random Forest 
## 
## 5655 samples
##   19 predictor
##    2 classes: 'no', 'yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times) 
## Summary of sample sizes: 5089, 5090, 5089, 5089, 5089, 5090, ... 
## Resampling results across tuning parameters:
## 
##   mtry  ROC        Sens       Spec     
##    2    0.9872081  0.9651120  0.9444082
##    3    0.9902507  0.9782975  0.9456825
##    4    0.9914815  0.9834923  0.9448328
##    5    0.9921455  0.9860024  0.9456471
##    6    0.9924712  0.9854018  0.9472749
##    7    0.9926183  0.9850835  0.9480889
##    8    0.9926826  0.9843061  0.9481238
##    9    0.9926634  0.9839525  0.9484774
##   10    0.9925920  0.9832808  0.9485486
##   11    0.9925968  0.9826797  0.9484064
##   12    0.9924771  0.9820075  0.9479463
##   13    0.9923007  0.9813362  0.9474152
##   14    0.9922675  0.9811241  0.9470975
##   15    0.9921257  0.9801339  0.9471338
##   16    0.9919006  0.9797105  0.9466030
##   17    0.9918002  0.9791095  0.9466733
##   18    0.9916376  0.9784737  0.9460016
##   19    0.9915187  0.9785441  0.9455416
## 
## ROC was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 8.

3.2 Features importance

3.2.1 Kappa-optimized model

# caret's variable importance for the kappa-optimized random forest.
imp_vars_rf <- varImp(fit_rf)

plot(imp_vars_rf, main = "Variable Importance with RF")

3.2.2 ROC-optimized model

it’s the same

3.3 Testing

3.3.1 ROC & AUC

a function for roc-stuff

# Compute ROC performance for a fitted caret model on a test set.
#
# fit.obj:    a caret `train` object fitted with classProbs = TRUE
# testing.df: test data frame containing the `stroke` reference column
#
# Returns a list with the ROCR tpr/fpr performance object first and the
# ROCR prediction object second. Elements are now named (perf, pred) for
# readability; positional access via [[1]]/[[2]] still works, so
# existing callers are unaffected.
get_roc <- function(fit.obj, testing.df){
  # probability of the positive ("yes") class
  pred_prob <- predict.train(fit.obj, newdata = testing.df, type = "prob")
  pred_roc <- prediction(predictions = pred_prob$yes, labels = testing.df$stroke)
  perf_roc <- performance(pred_roc, measure = "tpr", x.measure = "fpr")
  list(perf = perf_roc, pred = pred_roc)
}

3.3.1.1 ROC-curve for kappa-optimized model

# ROC curve for the kappa-optimized model on the test set
perf_pred <- get_roc(fit_rf, df_test)
perf_rf <- perf_pred[[1]]
pred_rf <- perf_pred[[2]]

# AUC: pull the single value out of the ROCR performance object's
# y.values slot (equivalent to unlist(slot(..., "y.values")))
auc_rf <- round(performance(pred_rf, measure = "auc")@y.values[[1]], 3)

# plot the curve with the diagonal chance line and AUC in the legend
plot(perf_rf, main = "RF-k ROC curve", col = "steelblue", lwd = 3)
abline(a = 0, b = 1, lwd = 3, lty = 2, col = 1)
legend(x = 0.7, y = 0.3, legend = paste0("AUC = ", auc_rf))

3.3.1.2 ROC-curve for ROC-optimized model

# ROC curve for the ROC-optimized model on the test set
perf_pred_roc <- get_roc(fit_rf_roc, df_test)
perf_rf_roc <- perf_pred_roc[[1]]
pred_rf_roc <- perf_pred_roc[[2]]

# AUC: pull the single value out of the ROCR performance object's
# y.values slot (equivalent to unlist(slot(..., "y.values")))
auc_rf_roc <- round(performance(pred_rf_roc, measure = "auc")@y.values[[1]], 3)

# plot the curve with the diagonal chance line and AUC in the legend
plot(perf_rf_roc, main = "RF-r ROC curve", col = "steelblue", lwd = 3)
abline(a = 0, b = 1, lwd = 3, lty = 2, col = 1)
legend(x = 0.7, y = 0.3, legend = paste0("AUC = ", auc_rf_roc))

So, we can adjust TPR/FPR cutoff to predict all strokes

3.3.2 TPR, FPR vs Probability cutoff

At which probability cut-off, you’ll get TPR = 1.0?

# use pred_rf (pred_roc) object
# TPR as a function of the probability cutoff...
plot(performance(pred_rf, measure = "tpr", x.measure = "cutoff"),
     col="steelblue", 
     ylab = "Rate", 
     xlab="Probability cutoff")

# ...with FPR overlaid on the same axes
plot(performance(pred_rf, measure = "fpr", x.measure = "cutoff"), 
     add = T, col = "red")

legend(x = 0.6,y = 0.7, c("TPR (Recall)", "FPR (1-Spec)"), 
       lty = 1, col =c('steelblue', 'red'), bty = 'n', cex = 1, lwd = 2)

# optional marker at cutoff = 0.02 referenced in the text below
#abline(v = 0.02, lwd = 2, lty=6)

title("RF-k")

Vertical line at cutoff = 0.02 designates maximum TPR and maximum FPR. Ideal cutoff should be to the left of this line

# use pred_rf (pred_roc) object
# TPR as a function of the probability cutoff...
plot(performance(pred_rf_roc, measure = "tpr", x.measure = "cutoff"),
     col = "steelblue", 
     ylab = "Rate", 
     xlab = "Probability cutoff")

# ...with FPR overlaid on the same axes
plot(performance(pred_rf_roc, measure = "fpr", x.measure = "cutoff"), 
     add = T, col = "red")

legend(x = 0.6,y = 0.7, c("TPR (Recall)", "FPR (1-Spec)"), 
       lty = 1, col = c('steelblue', 'red'), bty = 'n', cex = 1, lwd = 2)

# optional marker at cutoff = 0.02 referenced in the text below
#abline(v = 0.02, lwd = 2, lty=6)

title("RF-r")

Vertical line at 0.02

3.3.3 Confusion matrix

3.3.3.1 Kappa-optimized

Using desired cut-off: we want to maximize TPR (Sensitivity, Recall)!

According to the TPR/FPR plot (above), the optimal cutoff is around 0.01, which is used below.

# predict probabilities
pred_prob_rf <- predict(fit_rf, newdata = df_test, type = "prob")

# choose your cut-off (deliberately low: maximize recall of strokes);
# use <- for assignment rather than =
cutoff <- 0.01

# turn probabilities into classes; setting the levels explicitly guards
# against confusionMatrix() failing when only one class is predicted
pred_class_rf <- factor(ifelse(pred_prob_rf$yes > cutoff, "yes", "no"),
                        levels = c("no", "yes"))

cm_rf <- confusionMatrix(data = pred_class_rf, 
                reference = df_test$stroke,
                mode = "everything",
                positive = "yes")

cm_rf
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  no yes
##        no  488   3
##        yes 683  49
##                                          
##                Accuracy : 0.4391         
##                  95% CI : (0.411, 0.4674)
##     No Information Rate : 0.9575         
##     P-Value [Acc > NIR] : 1              
##                                          
##                   Kappa : 0.0495         
##                                          
##  Mcnemar's Test P-Value : <2e-16         
##                                          
##             Sensitivity : 0.94231        
##             Specificity : 0.41674        
##          Pos Pred Value : 0.06694        
##          Neg Pred Value : 0.99389        
##               Precision : 0.06694        
##                  Recall : 0.94231        
##                      F1 : 0.12500        
##              Prevalence : 0.04252        
##          Detection Rate : 0.04007        
##    Detection Prevalence : 0.59853        
##       Balanced Accuracy : 0.67952        
##                                          
##        'Positive' Class : yes            
## 

3.3.3.2 ROC-optimized

# predict probabilities
pred_prob_rf_roc <- predict(fit_rf_roc, newdata = df_test, type = "prob")

# choose your cut-off; use <- for assignment rather than =
cutoff <- 0.01

# turn probabilities into classes; setting the levels explicitly guards
# against confusionMatrix() failing when only one class is predicted
pred_class_rf_roc <- factor(ifelse(pred_prob_rf_roc$yes > cutoff, "yes", "no"),
                            levels = c("no", "yes"))

# NOTE(review): this overwrites cm_rf from the kappa-optimized section;
# a distinct name (e.g. cm_rf_roc) would be clearer, but the name is
# kept in case later code relies on it.
cm_rf <- confusionMatrix(data = pred_class_rf_roc, 
                reference = df_test$stroke,
                mode = "everything",
                positive = "yes")

cm_rf
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  no yes
##        no  505   4
##        yes 666  48
##                                          
##                Accuracy : 0.4522         
##                  95% CI : (0.424, 0.4806)
##     No Information Rate : 0.9575         
##     P-Value [Acc > NIR] : 1              
##                                          
##                   Kappa : 0.05           
##                                          
##  Mcnemar's Test P-Value : <2e-16         
##                                          
##             Sensitivity : 0.92308        
##             Specificity : 0.43126        
##          Pos Pred Value : 0.06723        
##          Neg Pred Value : 0.99214        
##               Precision : 0.06723        
##                  Recall : 0.92308        
##                      F1 : 0.12533        
##              Prevalence : 0.04252        
##          Detection Rate : 0.03925        
##    Detection Prevalence : 0.58381        
##       Balanced Accuracy : 0.67717        
##                                          
##        'Positive' Class : yes            
## 

4 AdaBoost

4.1 Training and validation

4.1.1 ROC-optimized

10-fold CV

set.seed(122)

# cluster already started in the Random Forest section
#cl <- makePSOCKcluster(THREADS)
#registerDoParallel(cl)

# AdaBoost.M1 with an automatic tuning grid (tuneLength = 10).
# Fix: metric was "Kappa", but fit_ctrl_roc uses twoClassSummary, which
# reports only ROC/Sens/Spec — caret warns and falls back to ROC in that
# case (the printed results below confirm ROC was used). Request "ROC"
# explicitly, which also matches the section title.
fit_adb <- train(stroke ~ ., 
                 data = df_train_smote, 
                 metric = "ROC", 
                 method = "AdaBoost.M1", 
                 trControl = fit_ctrl_roc,
                 tuneLength = 10,
                 verbosity = 0,
                 verbose = FALSE)
# coeflearn=Freund was chosen by automatic grid search, mfinal choice comes from there too
# NOTE(review): the visible grid rows show coeflearn = Breiman — confirm
# which coeflearn the final model actually selected.

# stop CLuster
stopCluster(cl)

fit_adb
## AdaBoost.M1 
## 
## 5655 samples
##   19 predictor
##    2 classes: 'no', 'yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times) 
## Summary of sample sizes: 5089, 5090, 5089, 5089, 5090, 5090, ... 
## Resampling results across tuning parameters:
## 
##   coeflearn  maxdepth  mfinal  ROC        Sens       Spec     
##   Breiman     1         50     0.9777189  0.8724294  0.9555591
##   Breiman     1        100     0.9843121  0.9122684  0.9531884
##   Breiman     1        150     0.9859259  0.9262983  0.9529057
##   Breiman     1        200     0.9859275  0.9225885  0.9526584
##   Breiman     1        250     0.9859236  0.9214240  0.9529776
##   Breiman     1        300     0.9859858  0.9209279  0.9544274
##   Breiman     1        350     0.9862647  0.9261940  0.9531184
##   Breiman     1        400     0.9863285  0.9273221  0.9529416
##   Breiman     1        450     0.9864843  0.9293386  0.9532254
##   Breiman     1        500     0.9865636  0.9293755  0.9527651
##   Breiman     2         50     0.9865974  0.9454931  0.9468194
##   Breiman     2        100     0.9872408  0.9502994  0.9456874
##   Breiman     2        150     0.9875402  0.9522092  0.9453688
##   Breiman     2        200     0.9876916  0.9544369  0.9447324
##   Breiman     2        250     0.9878425  0.9552151  0.9445559
##   Breiman     2        300     0.9880079  0.9575828  0.9440597
##   Breiman     2        350     0.9880121  0.9567360  0.9443072
##   Breiman     2        400     0.9880047  0.9573723  0.9448736
##   Breiman     2        450     0.9881015  0.9585388  0.9442721
##   Breiman     2        500     0.9881340  0.9579725  0.9444138
##   Breiman     3         50     0.9874795  0.9526360  0.9439521
##   Breiman     3        100     0.9883973  0.9627099  0.9433153
##   Breiman     3        150     0.9886993  0.9658201  0.9432449
##   Breiman     3        200     0.9889822  0.9668811  0.9440587
##   Breiman     3        250     0.9891082  0.9675162  0.9440932
##   Breiman     3        300     0.9891215  0.9685071  0.9442710
##   Breiman     3        350     0.9891442  0.9690730  0.9443422
##   Breiman     3        400     0.9892311  0.9695326  0.9443412
##   Breiman     3        450     0.9893029  0.9697444  0.9447652
##   Breiman     3        500     0.9893214  0.9711941  0.9440235
##   Breiman     4         50     0.9884225  0.9635218  0.9447303
##   Breiman     4        100     0.9891446  0.9690706  0.9455072
##   Breiman     4        150     0.9895984  0.9732420  0.9453673
##   Breiman     4        200     0.9897733  0.9743384  0.9461103
##   Breiman     4        250     0.9899460  0.9760350  0.9462162
##   Breiman     4        300     0.9900248  0.9755051  0.9465704
##   Breiman     4        350     0.9901296  0.9762826  0.9469230
##   Breiman     4        400     0.9901746  0.9763885  0.9468516
##   Breiman     4        450     0.9902223  0.9770248  0.9466040
##   Breiman     4        500     0.9902260  0.9767418  0.9468871
##   Breiman     5         50     0.9891837  0.9674117  0.9454386
##   Breiman     5        100     0.9899914  0.9750808  0.9466410
##   Breiman     5        150     0.9902461  0.9768834  0.9477379
##   Breiman     5        200     0.9904293  0.9769535  0.9483038
##   Breiman     5        250     0.9906032  0.9768477  0.9485158
##   Breiman     5        300     0.9907113  0.9767421  0.9489765
##   Breiman     5        350     0.9907313  0.9765649  0.9488338
##   Breiman     5        400     0.9908606  0.9768476  0.9486931
##   Breiman     5        450     0.9909274  0.9769183  0.9484799
##   Breiman     5        500     0.9909489  0.9771310  0.9487630
##   Breiman     6         50     0.9902362  0.9740194  0.9473140
##   Breiman     6        100     0.9909817  0.9768469  0.9485512
##   Breiman     6        150     0.9913142  0.9769879  0.9492235
##   Breiman     6        200     0.9914806  0.9772006  0.9494015
##   Breiman     6        250     0.9915893  0.9765650  0.9497897
##   Breiman     6        300     0.9916855  0.9776963  0.9498949
##   Breiman     6        350     0.9916391  0.9777665  0.9501069
##   Breiman     6        400     0.9917391  0.9773076  0.9498951
##   Breiman     6        450     0.9917564  0.9773429  0.9499302
##   Breiman     6        500     0.9917684  0.9767772  0.9502841
##   Breiman     7         50     0.9909562  0.9749389  0.9497554
##   Breiman     7        100     0.9916580  0.9757877  0.9511340
##   Breiman     7        150     0.9920635  0.9764230  0.9509576
##   Breiman     7        200     0.9923196  0.9761398  0.9511346
##   Breiman     7        250     0.9923537  0.9768828  0.9517009
##   Breiman     7        300     0.9924595  0.9770243  0.9515592
##   Breiman     7        350     0.9924834  0.9767413  0.9513111
##   Breiman     7        400     0.9925652  0.9765296  0.9513821
##   Breiman     7        450     0.9925628  0.9768476  0.9516305
##   Breiman     7        500     0.9925739  0.9762825  0.9517363
##   Breiman     8         50     0.9917185  0.9755745  0.9499677
##   Breiman     8        100     0.9924765  0.9759286  0.9520557
##   Breiman     8        150     0.9929143  0.9760696  0.9520203
##   Breiman     8        200     0.9931415  0.9759988  0.9517723
##   Breiman     8        250     0.9932598  0.9763523  0.9520551
##   Breiman     8        300     0.9933312  0.9761049  0.9523023
##   Breiman     8        350     0.9933860  0.9762460  0.9523733
##   Breiman     8        400     0.9934063  0.9768469  0.9526918
##   Breiman     8        450     0.9934732  0.9766357  0.9527978
##   Breiman     8        500     0.9934860  0.9765647  0.9525853
##   Breiman     9         50     0.9926802  0.9756809  0.9519128
##   Breiman     9        100     0.9933290  0.9766003  0.9537180
##   Breiman     9        150     0.9937317  0.9763876  0.9536822
##   Breiman     9        200     0.9938726  0.9760351  0.9533989
##   Breiman     9        250     0.9939771  0.9768123  0.9539657
##   Breiman     9        300     0.9940273  0.9765648  0.9535759
##   Breiman     9        350     0.9940349  0.9766001  0.9541776
##   Breiman     9        400     0.9940754  0.9763528  0.9541065
##   Breiman     9        450     0.9941410  0.9764590  0.9542127
##   Breiman     9        500     0.9941363  0.9766709  0.9537532
##   Breiman    10         50     0.9932476  0.9755394  0.9532930
##   Breiman    10        100     0.9938992  0.9763177  0.9537173
##   Breiman    10        150     0.9941438  0.9760702  0.9546733
##   Breiman    10        200     0.9943115  0.9761054  0.9547787
##   Breiman    10        250     0.9944159  0.9761764  0.9549561
##   Breiman    10        300     0.9945443  0.9763885  0.9547788
##   Breiman    10        350     0.9945730  0.9767063  0.9548851
##   Breiman    10        400     0.9946096  0.9764945  0.9549561
##   Breiman    10        450     0.9946020  0.9763885  0.9549209
##   Breiman    10        500     0.9945937  0.9765295  0.9550627
##   Freund      1         50     0.9853173  0.9206446  0.9506027
##   Freund      1        100     0.9857265  0.9200073  0.9518091
##   Freund      1        150     0.9864466  0.9308601  0.9507819
##   Freund      1        200     0.9867348  0.9343229  0.9500752
##   Freund      1        250     0.9869161  0.9369405  0.9496855
##   Freund      1        300     0.9869307  0.9378588  0.9497565
##   Freund      1        350     0.9870315  0.9386368  0.9496860
##   Freund      1        400     0.9870779  0.9389191  0.9495794
##   Freund      1        450     0.9870925  0.9382109  0.9493675
##   Freund      1        500     0.9871654  0.9401911  0.9496153
##   Freund      2         50     0.9869236  0.9559602  0.9415830
##   Freund      2        100     0.9875503  0.9576881  0.9432123
##   Freund      2        150     0.9879051  0.9626398  0.9417243
##   Freund      2        200     0.9881565  0.9635240  0.9419717
##   Freund      2        250     0.9883170  0.9644769  0.9429266
##   Freund      2        300     0.9884519  0.9653962  0.9424321
##   Freund      2        350     0.9885248  0.9655382  0.9425012
##   Freund      2        400     0.9885787  0.9663158  0.9422179
##   Freund      2        450     0.9885543  0.9668806  0.9425721
##   Freund      2        500     0.9886419  0.9673402  0.9428205
##   Freund      3         50     0.9877998  0.9608367  0.9421482
##   Freund      3        100     0.9885627  0.9681168  0.9437740
##   Freund      3        150     0.9888895  0.9708024  0.9441638
##   Freund      3        200     0.9891174  0.9720745  0.9448365
##   Freund      3        250     0.9892851  0.9739130  0.9455442
##   Freund      3        300     0.9893241  0.9743380  0.9450479
##   Freund      3        350     0.9894307  0.9753983  0.9450130
##   Freund      3        400     0.9894889  0.9758929  0.9459686
##   Freund      3        450     0.9895807  0.9759635  0.9455074
##   Freund      3        500     0.9896117  0.9769881  0.9461087
##   Freund      4         50     0.9889018  0.9673042  0.9450491
##   Freund      4        100     0.9895490  0.9740200  0.9460752
##   Freund      4        150     0.9899039  0.9735253  0.9467819
##   Freund      4        200     0.9898918  0.9744447  0.9470299
##   Freund      4        250     0.9900201  0.9747275  0.9480917
##   Freund      4        300     0.9900004  0.9743376  0.9473125
##   Freund      4        350     0.9900579  0.9751860  0.9473835
##   Freund      4        400     0.9900991  0.9748679  0.9477366
##   Freund      4        450     0.9901194  0.9745847  0.9477717
##   Freund      4        500     0.9901630  0.9744084  0.9477013
##   Freund      5         50     0.9895523  0.9724283  0.9481964
##   Freund      5        100     0.9903765  0.9739136  0.9485863
##   Freund      5        150     0.9905173  0.9740197  0.9485869
##   Freund      5        200     0.9906549  0.9743380  0.9490110
##   Freund      5        250     0.9907527  0.9738782  0.9495063
##   Freund      5        300     0.9907806  0.9747976  0.9492944
##   Freund      5        350     0.9909475  0.9745148  0.9491527
##   Freund      5        400     0.9909815  0.9744795  0.9488338
##   Freund      5        450     0.9910496  0.9750449  0.9494343
##   Freund      5        500     0.9910783  0.9744082  0.9496468
##   Freund      6         50     0.9903656  0.9725705  0.9488705
##   Freund      6        100     0.9910913  0.9736660  0.9503208
##   Freund      6        150     0.9915119  0.9742314  0.9511707
##   Freund      6        200     0.9917769  0.9741611  0.9514177
##   Freund      6        250     0.9919117  0.9742670  0.9520191
##   Freund      6        300     0.9920402  0.9745143  0.9515233
##   Freund      6        350     0.9920691  0.9739135  0.9514183
##   Freund      6        400     0.9921408  0.9741959  0.9519135
##   Freund      6        450     0.9921904  0.9746202  0.9521963
##   Freund      6        500     0.9922505  0.9746201  0.9518415
##   Freund      7         50     0.9913120  0.9724284  0.9514156
##   Freund      7        100     0.9922303  0.9730646  0.9525131
##   Freund      7        150     0.9926749  0.9740897  0.9532569
##   Freund      7        200     0.9928864  0.9745141  0.9535395
##   Freund      7        250     0.9929598  0.9747261  0.9532209
##   Freund      7        300     0.9930487  0.9740899  0.9537523
##   Freund      7        350     0.9931199  0.9745845  0.9537516
##   Freund      7        400     0.9931741  0.9748682  0.9541052
##   Freund      7        450     0.9932382  0.9751511  0.9543531
##   Freund      7        500     0.9932645  0.9748682  0.9542115
##   Freund      8         50     0.9922808  0.9731006  0.9536456
##   Freund      8        100     0.9931922  0.9741256  0.9547073
##   Freund      8        150     0.9935307  0.9740898  0.9550266
##   Freund      8        200     0.9937153  0.9744438  0.9549214
##   Freund      8        250     0.9938616  0.9742668  0.9552040
##   Freund      8        300     0.9938975  0.9747265  0.9557352
##   Freund      8        350     0.9939477  0.9747617  0.9555229
##   Freund      8        400     0.9939748  0.9747270  0.9558768
##   Freund      8        450     0.9939759  0.9744087  0.9557348
##   Freund      8        500     0.9940409  0.9745500  0.9556993
##   Freund      9         50     0.9929256  0.9739487  0.9549550
##   Freund      9        100     0.9936573  0.9748322  0.9556992
##   Freund      9        150     0.9939422  0.9749035  0.9557344
##   Freund      9        200     0.9941174  0.9745854  0.9562295
##   Freund      9        250     0.9941936  0.9750801  0.9567251
##   Freund      9        300     0.9942773  0.9752922  0.9566193
##   Freund      9        350     0.9943566  0.9751153  0.9566544
##   Freund      9        400     0.9944298  0.9752923  0.9568321
##   Freund      9        450     0.9944423  0.9750096  0.9569378
##   Freund      9        500     0.9944597  0.9749743  0.9565836
##   Freund     10         50     0.9932535  0.9742673  0.9556280
##   Freund     10        100     0.9938104  0.9750457  0.9561946
##   Freund     10        150     0.9941267  0.9752227  0.9563713
##   Freund     10        200     0.9942849  0.9753991  0.9566192
##   Freund     10        250     0.9943845  0.9757521  0.9567606
##   Freund     10        300     0.9945301  0.9756102  0.9567254
##   Freund     10        350     0.9945800  0.9756457  0.9570087
##   Freund     10        400     0.9946073  0.9756454  0.9568664
##   Freund     10        450     0.9946510  0.9754690  0.9567249
##   Freund     10        500     0.9946462  0.9755396  0.9567249
##   Zhu         1         50     0.9848674  0.9171448  0.9506403
##   Zhu         1        100     0.9858034  0.9275384  0.9510284
##   Zhu         1        150     0.9864642  0.9330861  0.9499687
##   Zhu         1        200     0.9867275  0.9333325  0.9504295
##   Zhu         1        250     0.9868298  0.9342506  0.9500752
##   Zhu         1        300     0.9869588  0.9370441  0.9498280
##   Zhu         1        350     0.9869826  0.9354179  0.9502170
##   Zhu         1        400     0.9870457  0.9370091  0.9496510
##   Zhu         1        450     0.9870464  0.9376448  0.9494745
##   Zhu         1        500     0.9870759  0.9368683  0.9494744
##   Zhu         2         50     0.9869956  0.9535197  0.9431397
##   Zhu         2        100     0.9877857  0.9597756  0.9424662
##   Zhu         2        150     0.9880038  0.9608720  0.9426102
##   Zhu         2        200     0.9882471  0.9628527  0.9433518
##   Zhu         2        250     0.9884515  0.9655016  0.9429260
##   Zhu         2        300     0.9884816  0.9643354  0.9427499
##   Zhu         2        350     0.9884639  0.9663866  0.9424314
##   Zhu         2        400     0.9885498  0.9665619  0.9430339
##   Zhu         2        450     0.9886025  0.9669500  0.9430698
##   Zhu         2        500     0.9886317  0.9675863  0.9432462
##   Zhu         3         50     0.9878142  0.9604803  0.9420754
##   Zhu         3        100     0.9886358  0.9662078  0.9430686
##   Zhu         3        150     0.9890801  0.9702005  0.9440570
##   Zhu         3        200     0.9892089  0.9714035  0.9445170
##   Zhu         3        250     0.9893610  0.9722164  0.9448015
##   Zhu         3        300     0.9894536  0.9732767  0.9449060
##   Zhu         3        350     0.9895536  0.9738782  0.9447656
##   Zhu         3        400     0.9896059  0.9738426  0.9453307
##   Zhu         3        450     0.9896768  0.9747969  0.9452965
##   Zhu         3        500     0.9896966  0.9760344  0.9456850
##   Zhu         4         50     0.9888058  0.9682937  0.9454381
##   Zhu         4        100     0.9896389  0.9732758  0.9475600
##   Zhu         4        150     0.9899969  0.9743011  0.9481608
##   Zhu         4        200     0.9902162  0.9746192  0.9478429
##   Zhu         4        250     0.9902468  0.9752204  0.9480545
##   Zhu         4        300     0.9902258  0.9743724  0.9476658
##   Zhu         4        350     0.9902822  0.9745489  0.9477713
##   Zhu         4        400     0.9903138  0.9745133  0.9483020
##   Zhu         4        450     0.9903504  0.9747965  0.9483727
##   Zhu         4        500     0.9903690  0.9736656  0.9481955
##   Zhu         5         50     0.9894375  0.9722889  0.9484446
##   Zhu         5        100     0.9902144  0.9732436  0.9489038
##   Zhu         5        150     0.9906025  0.9740211  0.9494355
##   Zhu         5        200     0.9907409  0.9744450  0.9492226
##   Zhu         5        250     0.9907783  0.9739145  0.9492229
##   Zhu         5        300     0.9908307  0.9740550  0.9489756
##   Zhu         5        350     0.9908354  0.9735606  0.9498955
##   Zhu         5        400     0.9908813  0.9737723  0.9498957
##   Zhu         5        450     0.9908890  0.9744794  0.9501078
##   Zhu         5        500     0.9909103  0.9747262  0.9499311
##   Zhu         6         50     0.9907319  0.9726410  0.9501423
##   Zhu         6        100     0.9914610  0.9736661  0.9512403
##   Zhu         6        150     0.9918256  0.9739492  0.9513116
##   Zhu         6        200     0.9920702  0.9745503  0.9517009
##   Zhu         6        250     0.9922090  0.9743379  0.9513112
##   Zhu         6        300     0.9922127  0.9744442  0.9519124
##   Zhu         6        350     0.9922410  0.9752566  0.9519834
##   Zhu         6        400     0.9922634  0.9749038  0.9518060
##   Zhu         6        450     0.9922311  0.9747622  0.9517709
##   Zhu         6        500     0.9922906  0.9748682  0.9523731
##   Zhu         7         50     0.9916123  0.9724632  0.9511344
##   Zhu         7        100     0.9924240  0.9742670  0.9522314
##   Zhu         7        150     0.9927288  0.9742314  0.9532227
##   Zhu         7        200     0.9929402  0.9744792  0.9532578
##   Zhu         7        250     0.9930734  0.9748331  0.9534343
##   Zhu         7        300     0.9931491  0.9749389  0.9537185
##   Zhu         7        350     0.9931682  0.9751155  0.9538592
##   Zhu         7        400     0.9932087  0.9749736  0.9542479
##   Zhu         7        450     0.9932393  0.9749390  0.9540711
##   Zhu         7        500     0.9932703  0.9748328  0.9543192
##   Zhu         8         50     0.9923216  0.9728893  0.9530451
##   Zhu         8        100     0.9930361  0.9743032  0.9541070
##   Zhu         8        150     0.9934584  0.9753278  0.9543881
##   Zhu         8        200     0.9936660  0.9754338  0.9550614
##   Zhu         8        250     0.9938263  0.9750446  0.9555215
##   Zhu         8        300     0.9938891  0.9750452  0.9553087
##   Zhu         8        350     0.9939491  0.9752220  0.9558395
##   Zhu         8        400     0.9939407  0.9749394  0.9556626
##   Zhu         8        450     0.9939583  0.9750803  0.9555921
##   Zhu         8        500     0.9940433  0.9751159  0.9554869
##   Zhu         9         50     0.9929050  0.9745512  0.9546728
##   Zhu         9        100     0.9936931  0.9752214  0.9558405
##   Zhu         9        150     0.9940190  0.9751504  0.9560526
##   Zhu         9        200     0.9942620  0.9754341  0.9560170
##   Zhu         9        250     0.9943982  0.9752220  0.9562651
##   Zhu         9        300     0.9944371  0.9752567  0.9562652
##   Zhu         9        350     0.9944998  0.9753633  0.9561239
##   Zhu         9        400     0.9945278  0.9754693  0.9562296
##   Zhu         9        450     0.9945478  0.9752220  0.9560526
##   Zhu         9        500     0.9945793  0.9752930  0.9561233
##   Zhu        10         50     0.9931527  0.9743376  0.9547788
##   Zhu        10        100     0.9938406  0.9757512  0.9561944
##   Zhu        10        150     0.9941418  0.9756809  0.9563360
##   Zhu        10        200     0.9942761  0.9759288  0.9561940
##   Zhu        10        250     0.9944319  0.9759638  0.9559823
##   Zhu        10        300     0.9945351  0.9758231  0.9563007
##   Zhu        10        350     0.9945528  0.9759999  0.9564777
##   Zhu        10        400     0.9945923  0.9761408  0.9565838
##   Zhu        10        450     0.9946541  0.9762119  0.9564777
##   Zhu        10        500     0.9946433  0.9759995  0.9565485
## 
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were mfinal = 450, maxdepth = 10
##  and coeflearn = Zhu.

4.2 Testing

4.2.1 ROC curve

# Build and draw the ROC curve for AdaBoost on the held-out test set.
perf_pred_adb <- get_roc(fit_adb, df_test)
perf_adb <- perf_pred_adb[[1]]   # ROCR performance object (TPR vs FPR)
pred_adb <- perf_pred_adb[[2]]   # ROCR prediction object (reused below)

# area under the curve, rounded for the legend label
auc_adb <- round(performance(pred_adb, measure = "auc")@y.values[[1]], 3)

# curve plus the diagonal no-skill reference line
plot(perf_adb, main = "AdaBoost ROC curve", col = "steelblue", lwd = 3)
abline(a = 0, b = 1, lwd = 3, lty = 2, col = 1)
legend(x = 0.7, y = 0.3, legend = paste0("AUC = ", auc_adb))

4.2.2 TPR, FPR vs Probability cutoff

At which probability cut-off will you get TPR = 1.0?

# use the pred_adb ROCR prediction object (the AdaBoost one, not pred_rf)
# to plot TPR and FPR against the probability cut-off
plot(performance(pred_adb, measure = "tpr", x.measure = "cutoff"),
     col="steelblue", 
     ylab = "Rate", 
     xlab="Probability cutoff")

# overlay FPR on the same axes (T -> TRUE: T is a reassignable alias)
plot(performance(pred_adb, measure = "fpr", x.measure = "cutoff"), 
     add = TRUE, col = "red")

legend(x = 0.6,y = 0.7, c("TPR (Recall)", "FPR (1-Spec)"), 
       lty = 1, col =c('steelblue', 'red'), bty = 'n', cex = 1, lwd = 2)

#abline(v = 0.1, lwd = 2, lty=6)

title("AdaBoost.M1")

4.2.3 Confusion matrix

# Score AdaBoost on the held-out test set at a chosen cut-off.

# predict class probabilities for the test observations
pred_prob_adb <- predict(fit_adb, newdata = df_test, type = "prob")

# probability cut-off for calling a case "yes"
cutoff <- 0.12

# turn probabilities into classes; set the levels explicitly so the factor
# always matches the reference, even if one class is never predicted
pred_class_adb <- factor(ifelse(pred_prob_adb$yes > cutoff, "yes", "no"),
                         levels = levels(df_test$stroke))

cm_adb <- confusionMatrix(data = pred_class_adb, 
                reference = df_test$stroke,
                mode = "everything",
                positive = "yes")

cm_adb
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   no  yes
##        no    34    0
##        yes 1137   52
##                                           
##                Accuracy : 0.0703          
##                  95% CI : (0.0566, 0.0861)
##     No Information Rate : 0.9575          
##     P-Value [Acc > NIR] : 1               
##                                           
##                   Kappa : 0.0025          
##                                           
##  Mcnemar's Test P-Value : <2e-16          
##                                           
##             Sensitivity : 1.00000         
##             Specificity : 0.02904         
##          Pos Pred Value : 0.04373         
##          Neg Pred Value : 1.00000         
##               Precision : 0.04373         
##                  Recall : 1.00000         
##                      F1 : 0.08380         
##              Prevalence : 0.04252         
##          Detection Rate : 0.04252         
##    Detection Prevalence : 0.97220         
##       Balanced Accuracy : 0.51452         
##                                           
##        'Positive' Class : yes             
## 

5 Extreme Gradient Boosting: xgbTree

xgbTree has 7 tuning parameters

5.1 Params tuning

# Resampling scheme for the Kappa-optimized xgbTree model:
# 10-fold cross-validation, repeated 50 times.
fit_ctrl_kp10 <- trainControl(method = "repeatedcv",
                           number = 10,
                           repeats = 50, 
                           allowParallel = TRUE)  # TRUE, not the alias T

5.2 Training and validation

5.2.1 Kappa-optimized

10-fold CV, repeated 50 times

set.seed(121)

# Kappa-optimized xgbTree over a manual grid of max_depth and gamma.
fit_xgb_kp <- train(stroke ~ ., 
                 data = df_train_smote, 
                 method = "xgbTree",
                 metric = "Kappa", 
                 trControl = fit_ctrl_kp10,
                 tuneGrid = expand.grid(
                   .nrounds = 50,
                   .max_depth = seq(2, 10, 1),
                   .eta = 0.3,
                   .gamma = c(0.005, 0.01, 0.015),
                   .colsample_bytree = 1,
                   .min_child_weight = 1,
                   # subsample is the row-sampling *fraction* and must lie
                   # in (0, 1]; the original value 3 is out of range
                   .subsample = 1
                 ),
                 nthread = 10,  # xgboost's thread argument is "nthread"
                 verbose = FALSE,
                 verbosity = 0)

fit_xgb_kp$bestTune

5.3 Features importance

# rank predictors by their contribution to the fitted xgbTree model
imp_vars_xgb <- varImp(fit_xgb_kp)

plot(imp_vars_xgb, main = "Variable Importance with XGB")

5.4 Testing

5.4.1 ROC curve

# calculate ROC
# Build and draw the ROC curve for xgbTree on the held-out test set.
perf_pred_xgb <- get_roc(fit_xgb_kp, df_test)
perf_xgb <- perf_pred_xgb[[1]]   # ROCR performance object (TPR vs FPR)
pred_xgb <- perf_pred_xgb[[2]]   # ROCR prediction object (reused below)

# area under the curve, rounded for the legend label
auc_xgb <- round(performance(pred_xgb, measure = "auc")@y.values[[1]], 3)

# curve plus the diagonal no-skill reference line
plot(perf_xgb, main = "XGB ROC curve", col = "steelblue", lwd = 3)
abline(a = 0, b = 1, lwd = 3, lty = 2, col = 1)
legend(x = 0.7, y = 0.3, legend = paste0("AUC = ", auc_xgb))

5.4.2 TPR v FPR

# use the pred_xgb ROCR prediction object to plot TPR and FPR
# against the probability cut-off
plot(performance(pred_xgb, measure = "tpr", x.measure = "cutoff"),
     col = "steelblue", 
     ylab = "Rate", 
     xlab = "Probability cutoff")

# overlay FPR on the same axes (T -> TRUE: T is a reassignable alias)
plot(performance(pred_xgb, measure = "fpr", x.measure = "cutoff"), 
     add = TRUE, col = "red")

legend(x = 0.6,y = 0.7, c("TPR (Recall)", "FPR (1-Spec)"), 
       lty = 1, col = c('steelblue', 'red'), bty = 'n', cex = 1, lwd = 2)

#abline(v = 0.1, lwd = 2, lty=6)

title("xgbTree")

5.4.3 Confusion matrix

# Score xgbTree on the held-out test set at a chosen cut-off.

# predict class probabilities for the test observations
pred_prob_xgb <- predict(fit_xgb_kp, newdata = df_test, type = "prob")

# probability cut-off for calling a case "yes"
cutoff <- 0.12

# turn probabilities into classes; set the levels explicitly so the factor
# always matches the reference, even if one class is never predicted
pred_class_xgb <- factor(ifelse(pred_prob_xgb$yes > cutoff, "yes", "no"),
                         levels = levels(df_test$stroke))

cm_xgb <- confusionMatrix(data = pred_class_xgb, 
                reference = df_test$stroke,
                mode = "everything",
                positive = "yes")

cm_xgb

6 Save the workspace


save.image("data/workspace.RData")